library("reticulate")
## Warning: package 'reticulate' was built under R version 4.0.5
library("knitr")
library("Hmisc")
## Loading required package: lattice
## Loading required package: survival
## Loading required package: Formula
## Loading required package: ggplot2
## 
## Attaching package: 'Hmisc'
## The following objects are masked from 'package:base':
## 
##     format.pval, units
library("DescTools")
## 
## Attaching package: 'DescTools'
## The following objects are masked from 'package:Hmisc':
## 
##     %nin%, Label, Mean, Quantile
library("stringr")
library("egg")
## Loading required package: gridExtra
library("tidyverse")
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ tibble  3.1.6     ✓ purrr   0.3.4
## ✓ tidyr   1.1.4     ✓ dplyr   1.0.7
## ✓ readr   2.1.1     ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::combine()   masks gridExtra::combine()
## x dplyr::filter()    masks stats::filter()
## x dplyr::lag()       masks stats::lag()
## x dplyr::src()       masks Hmisc::src()
## x dplyr::summarize() masks Hmisc::summarize()
# Global ggplot2 theme: classic look with a 24 pt base text size.
theme_set(
  theme_classic() +
    theme(text = element_text(size = 24))
)

# knitr chunk display options: no prefix on printed output, and text output
# and figures are held and shown together for each chunk.
opts_chunk$set(
  comment = "",
  results = "hold",
  fig.show = "hold"
)

# Suppress the summarise() regrouping message.
# FALSE is used instead of F because F is an ordinary, reassignable variable.
options(dplyr.summarise.inform = FALSE)

Load Data

# Select the "plinko" conda environment, import pandas through reticulate,
# and read the pickled dataset into df_data.
use_condaenv("plinko")
pd <- import("pandas")
df_data <- pd$read_pickle(
  "../../data/human_data/full_dataset_vision_corrected.xz"
)

Segment judgment, rt, and eye-data

# Distinct (participant, trial, response) rows.
df_data_judge <- unique(select(df_data, participant, trial, response))

# Response time per participant and trial: the last value of `t` minus the
# first, together with its natural log.
df_data_rt <- df_data %>%
  group_by(participant, trial) %>%
  summarise(rt = t[length(t)] - t[1]) %>%
  mutate(log_rt = log(rt))

Compute Judgment Means

# For participants 1-15: per-trial share of responses equal to 1, 2 and 3,
# reshaped to one row per (trial, hole).
df_data_mean_judge_train <- df_data_judge %>%
  filter(participant %in% 1:15) %>%
  group_by(trial) %>%
  summarise(
    hole1 = sum(response == 1) / n(),
    hole2 = sum(response == 2) / n(),
    hole3 = sum(response == 3) / n()
  ) %>%
  pivot_longer(
    cols = c(hole1, hole2, hole3),
    names_to = "hole",
    values_to = "human_mean"
  )

Compute RT Means

# For participants 1-15: per-trial mean response time and mean log response
# time. A response time of exactly 0 contributes 0 instead of log(0).
df_data_mean_rt_train <- df_data_rt %>%
  filter(participant %in% 1:15) %>%
  mutate(log_rt = ifelse(rt == 0, 0, log(rt))) %>%
  group_by(trial) %>%
  summarise(
    mean_rt = mean(rt),
    mean_log_rt = mean(log_rt)
  )

Results Visualization

Bandit Model

# Bandit model output; the X column from the CSV is dropped.
df_model_judge_rt <- read.csv(
  "../python/model/model_performance/grid_judgment_rt/bandit_runs_30_threshold_0.95_tradeoff_0.003_sample_weight_950_bw_30.0_noise_params_0.2_0.8_0.2_trial_0_150.csv"
) %>%
  select(-X)

Judgments

# Proportion of model runs that chose each hole, per trial.
#
# Judgments are shifted by one (0-based -> 1-based) and turned into a factor.
# The denominator is the number of runs for the trial (max run index + 1, i.e.
# run indices are taken to start at 0). It has to be computed at the trial
# level: inside a (trial, judgment) group, max(run) is only the largest run
# index that picked that hole, which inflates the proportion whenever the
# last run chose a different hole.
# Trial/judgment combinations that never occur are filled in with 0.
df_model_mean_judge <- df_model_judge_rt %>%
  mutate(judgment = factor(judgment + 1)) %>%
  group_by(trial) %>%
  mutate(n_runs = max(run) + 1) %>%
  group_by(trial, judgment) %>%
  summarise(model_mean = n() / first(n_runs)) %>%
  ungroup() %>%
  complete(trial, judgment,
           fill = list(model_mean = 0))

# Human choice proportions per (trial, hole) with bootstrapped confidence
# limits: each response becomes three 0/1 indicators, which are stacked and
# summarised with smean.cl.boot().
df_data_mean_judge_full <- df_data_judge %>%
  mutate(
    hole1 = as.numeric(response == 1),
    hole2 = as.numeric(response == 2),
    hole3 = as.numeric(response == 3)
  ) %>%
  select(-response) %>%
  pivot_longer(
    cols = c(hole1, hole2, hole3),
    names_to = "hole",
    values_to = "response"
  ) %>%
  group_by(trial, hole) %>%
  do(data.frame(rbind(smean.cl.boot(.$response)))) %>%
  rename(
    human_mean = Mean,
    lower = Lower,
    upper = Upper
  )

# Same table keyed by `judgment`: the last character of the hole name,
# stored as a factor.
df_human_mean_judge <- df_data_mean_judge_full %>%
  mutate(hole = as.factor(str_sub(hole, -1, -1))) %>%
  rename(judgment = hole)

# Model proportions joined to the human proportions, labelled "Bandit".
df_to_show <- df_model_mean_judge %>%
  left_join(df_human_mean_judge, by = c("trial", "judgment")) %>%
  mutate(model = "Bandit")

# Linear rescaling of the model predictions onto the human means.
scaled_model <- lm(human_mean ~ model_mean, data = df_to_show)
scaled_model_predictions <- predict(scaled_model)

# Correlation and RMSE of the rescaled predictions, rounded to two digits.
model_cor <- round(
  cor(scaled_model_predictions, df_to_show$human_mean),
  digits = 2
)
model_rmse <- round(
  RMSE(scaled_model_predictions, df_to_show$human_mean),
  digits = 2
)

# Model prediction vs. human proportion: identity line, CI ranges, points,
# linear fit, and the r / rmse annotations in the top-left corner.
ggplot(df_to_show, aes(x = model_mean, y = human_mean)) +
  geom_abline(intercept = 0, slope = 1, linetype = "dotted") +
  geom_linerange(aes(ymin = lower, ymax = upper), alpha = 0.2) +
  geom_point(alpha = 0.3) +
  geom_smooth(method = "lm", formula = y ~ x) +
  facet_grid(~ model) +
  labs(
    x = "Model Prediction",
    y = "Participant Response Proportion"
  ) +
  annotate("text",
           x = 0.0, y = 1, hjust = 0,
           label = paste("r: ", model_cor)) +
  annotate("text",
           x = 0.0, y = 0.95, hjust = 0,
           label = paste("rmse: ", model_rmse)) +
  theme(
    plot.title = element_text(size = 20, hjust = 0.5),
    axis.title = element_text(size = 16),
    axis.text = element_text(size = 10)
  )



# ggsave("figures/bandit_judgment.jpg", height = 4, width = 5)

# How many model predictions are exactly 0 or exactly 1.
temp <- df_to_show$model_mean
sum(temp == 0 | temp == 1)
[1] 358
# Number of model predictions that are exactly 1
temp <- df_model_mean_judge$model_mean
sum(temp == 1)
[1] 107

Response Time

# Per-trial model means of num_cols and of its log; a zero count is kept as
# is instead of being logged.
df_model_mean_rt <- df_model_judge_rt %>%
  mutate(log_cols = ifelse(num_cols == 0, num_cols, log(num_cols))) %>%
  group_by(trial) %>%
  summarise(
    mean_time = mean(num_cols),
    mean_log_time = mean(log_cols)
  )

# Per-trial human mean response time and mean log response time.
df_data_mean_rt <- df_data_rt %>%
  group_by(trial) %>%
  summarise(
    mean_rt = mean(rt),
    mean_log_rt = mean(log(rt))
  )

df_to_show <- df_model_mean_rt %>%
  left_join(df_data_mean_rt, by = c("trial"))

# Correlation and RMSE between the unscaled model measure and the mean RT.
model_cor <- round(
  cor(df_to_show$mean_time, df_to_show$mean_rt),
  digits = 2
)
model_rmse <- round(
  RMSE(df_to_show$mean_time, df_to_show$mean_rt),
  digits = 2
)

# Model mean num_cols vs. participant mean response time, with a linear fit
# and the r / rmse annotations.
# The y aesthetic is mean_rt (the raw mean, not the log), so the axis label
# says "Response Time" rather than "Log Response Time".
ggplot(data = df_to_show, mapping = aes(x = mean_time, y = mean_rt)) +
  geom_point(alpha = 0.7,
             shape = 16) +
  geom_smooth(method = "lm",
              formula = y ~ x) +
  # geom_label(mapping = aes(label = trial)) +
  ggtitle("Bandit Response Time") +
  xlab("Model Mean Collisions Across Runs") +
  ylab("Participant Mean Response Time") +
  annotate("text",
           label = paste("r =", model_cor),
           x = 12,
           y = 2500) +
  annotate("text",
           label = paste("rmse =", model_rmse),
           x = 12,
           y = 2200) +
  theme(plot.title = element_text(size = 20, hjust = 0.5),
        axis.title = element_text(size = 14),
        axis.text = element_text(size = 12))

# ggsave("figures/bandit_rt.png", height = 4, width = 5)

# Per-trial mean of log_rt with bootstrapped confidence limits.
df_data_mean_rt <- df_data_rt %>%
  group_by(trial) %>%
  do(data.frame(rbind(smean.cl.boot(.$log_rt)))) %>%
  rename(
    mean_log_rt = Mean,
    lower = Lower,
    upper = Upper
  )

df_to_show <- df_model_mean_rt %>%
  left_join(df_data_mean_rt, by = c("trial")) %>%
  mutate(model = "Sequential Sampler")

# Optional filter, currently disabled:
# df_to_show = df_to_show %>%
#   filter(mean_log_time > 0.5)

# Linear rescaling of the model's mean log measure onto mean log RT.
scaled_model <- lm(mean_log_rt ~ 1 + mean_log_time, data = df_to_show)
scaled_model_predictions <- predict(scaled_model)

model_cor <- round(
  cor(scaled_model_predictions, df_to_show$mean_log_rt),
  digits = 2
)
model_rmse <- round(
  RMSE(scaled_model_predictions, df_to_show$mean_log_rt),
  digits = 2
)

# Model mean log measure vs. human mean log RT with CI ranges, a linear fit
# and the r / rmse annotations; the figure is then written to PDF.
ggplot(df_to_show, aes(x = mean_log_time, y = mean_log_rt)) +
  geom_linerange(aes(ymin = lower, ymax = upper), alpha = 0.15) +
  geom_point(alpha = 0.7, shape = 16) +
  geom_smooth(method = "lm", formula = y ~ x) +
  facet_grid(~ model) +
  labs(
    x = "Model Mean Log Collisions",
    y = "Mean Log Response Time"
  ) +
  annotate("text",
           x = 0, y = 8.8, hjust = 0, size = 6,
           label = paste("r: ", model_cor)) +
  annotate("text",
           x = 0, y = 8.68, hjust = 0, size = 6,
           label = paste("rmse: ", model_rmse)) +
  theme(
    plot.title = element_text(size = 20, hjust = 0.5),
    axis.title = element_text(size = 20),
    axis.text = element_text(size = 16),
    plot.margin = margin(10, 1, 1, 10)
  )

ggsave("figures/bandit_log_rt.pdf", width = 5, height = 4)

# Bandit RT table with the time columns renamed to time_measure / log_time.
df_bandit_rt <- rename(
  df_to_show,
  time_measure = mean_time,
  log_time = mean_log_time
)

Fixed Sample Model

# Fixed-sample model output; the X column from the CSV is dropped.
df_fixed_sample_judge_rt <- read.csv(
  "../python/model/model_performance/grid_judgment_rt/fixed_sample_num_samples_40_bw_50.0_noise_params_0.2_0.8_0.2_trial_0_150.csv"
) %>%
  select(-X)

Judgments

# Fixed-sample predictions in long form: one row per (trial, hole).
df_fixed_sample_long <- df_fixed_sample_judge_rt %>%
  select(trial, hole1, hole2, hole3) %>%
  pivot_longer(
    cols = c(hole1, hole2, hole3),
    names_to = "hole",
    values_to = "prediction"
  )
  # mutate(judgment = factor(judgment))

# Predictions joined to the human proportions, labelled "Fixed Sample".
df_to_show <- left_join(
  df_fixed_sample_long,
  df_data_mean_judge_full,
  by = c("trial", "hole")
) %>%
  mutate(model = "Fixed Sample")

# Correlation and RMSE of the raw predictions against the human means.
fixed_sample_cor <- round(
  cor(df_to_show$prediction, df_to_show$human_mean),
  digits = 2
)
fixed_sample_rmse <- round(
  RMSE(df_to_show$prediction, df_to_show$human_mean),
  digits = 2
)

# Fixed-sample prediction vs. human proportion: identity line, CI ranges,
# points, linear fit and the r / rmse annotations.
ggplot(df_to_show, aes(x = prediction, y = human_mean)) +
  geom_abline(intercept = 0, slope = 1, linetype = "dotted") +
  geom_linerange(aes(ymin = lower, ymax = upper), alpha = 0.2) +
  geom_point(alpha = 0.5, shape = 16) +
  geom_smooth(method = "lm", formula = y ~ x) +
  annotate("text",
           x = 0.0, y = 1, hjust = 0,
           label = paste("r:", fixed_sample_cor)) +
  annotate("text",
           x = 0.0, y = 0.95, hjust = 0,
           label = paste("rmse:", fixed_sample_rmse)) +
  facet_grid(~ model) +
  labs(
    x = "Model Prediction",
    y = "Participant Mean Judgment"
  ) +
  theme(
    plot.title = element_text(size = 20, hjust = 0.5),
    axis.title = element_text(size = 16),
    axis.text = element_text(size = 10)
  )

# ggsave("figures/fixed_sample_judgments.jpg", height=4, width=5)

# How many predictions, rounded to 6 decimals, are exactly 0 or exactly 1.
temp <- round(df_to_show$prediction, digits = 6)
sum(temp == 0 | temp == 1)
[1] 187
# Fixed-sample judgment table keyed by `judgment` (last character of the hole
# name, as a factor) and relabelled "Uniform Sampler"; `hole` is dropped.
df_fixed_judge <- df_to_show %>%
  mutate(
    judgment = as.factor(str_sub(hole, -1, -1)),
    model = "Uniform Sampler"
  ) %>%
  select(-hole)

Response Time

# Fixed-sample time measure joined to the human mean log RT table.
# A zero num_cols is kept as 0 instead of being logged, matching how the
# bandit model's log measure is built in this file; log(0) is -Inf, which
# lm() rejects.
df_to_show <- df_fixed_sample_judge_rt %>%
  select(trial, num_cols) %>%
  mutate(time_measure = num_cols,
         log_time = ifelse(num_cols != 0, log(num_cols), num_cols)) %>%
  left_join(df_data_mean_rt, by = "trial") %>%
  mutate(model = "Uniform Sampler")


# Linear rescaling of the model's log measure onto mean log RT.
scaled_model <- lm(mean_log_rt ~ 1 + log_time,
                   data = df_to_show)

scaled_model_predictions <- predict(scaled_model)

# Correlation and RMSE of the rescaled predictions, rounded to two digits.
fixed_sample_rt_cor <- round(cor(scaled_model_predictions, df_to_show$mean_log_rt), digits = 2)
fixed_sample_rt_rmse <- round(RMSE(scaled_model_predictions, df_to_show$mean_log_rt), digits = 2)

# Fixed-sample log measure vs. human mean log RT with CI ranges, a linear
# fit and the r / rmse annotations; the figure is then written to PDF.
ggplot(df_to_show, aes(x = log_time, y = mean_log_rt)) +
  geom_linerange(aes(ymin = lower, ymax = upper), alpha = 0.15) +
  geom_point(alpha = 0.7, shape = 16) +
  geom_smooth(method = "lm", formula = y ~ x) +
  facet_grid(~ model) +
  labs(
    x = "Model Log Collisions",
    y = "Mean log Response Time"
  ) +
  annotate("text",
           x = 5.1, y = 8.8, size = 6, hjust = 0,
           label = paste("r:", fixed_sample_rt_cor)) +
  annotate("text",
           x = 5.1, y = 8.68, size = 6, hjust = 0,
           label = paste("rmse:", fixed_sample_rt_rmse)) +
  theme(
    plot.title = element_text(size = 20, hjust = 0.5),
    axis.title = element_text(size = 20),
    axis.text = element_text(size = 16),
    plot.margin = margin(10, 1, 1, 10)
  )

ggsave("figures/fixed_sample_rt.pdf", width = 5, height = 4)

Cogsci Figures

Judgments

# Bandit predictions reshaped to match the fixed-sample long table:
# hole names "hole1".."hole3", a `prediction` column and a model label.
df_bandit_judge <- df_model_mean_judge %>%
  mutate(judgment = paste0("hole", judgment),
         model = "Sequential Sampler") %>%
  rename(prediction = model_mean,
         hole = judgment)

df_fixed_judge <- df_fixed_sample_long %>%
  mutate(model = "Uniform Sampler")

# Both models stacked, then joined to the human proportions.
# The join keys are spelled out so the join cannot silently pick up any
# other column the two tables might come to share, and no "Joining, by"
# message is printed.
df_to_show <- rbind(df_bandit_judge,
                    df_fixed_judge) %>%
  left_join(df_data_mean_judge_full, by = c("trial", "hole"))
# Per-model correlation and RMSE between prediction and human mean,
# rounded to two digits.
df_sum_stat <- df_to_show %>%
  group_by(model) %>%
  summarise(
    r = cor(prediction, human_mean),
    rmse = RMSE(prediction, human_mean)
  ) %>%
  mutate(
    r = round(r, digits = 2),
    rmse = round(rmse, digits = 2)
  )

# Prediction vs. human proportion, one panel per model, with per-panel
# r / rmse labels taken from df_sum_stat and percent-labelled axes.
# geom_smooth() gets an explicit formula, as in the other plots in this
# file, so no formula message is printed when the plot is rendered or saved.
ggplot(df_to_show, mapping = aes(x = prediction, y = human_mean)) +
  geom_abline(slope = 1,
              intercept = 0,
              linetype = "dotted") +
  geom_linerange(mapping = aes(ymin = lower,
                               ymax = upper),
                 alpha = 0.2) +
  geom_point(alpha = 0.2) +
  geom_smooth(method = "lm",
              formula = y ~ x) +
  geom_text(data = df_sum_stat,
            x = 0.0,
            y = 1,
            size = 6,
            hjust = 0,
            mapping = aes(label = paste0("r: ", r))) +
  geom_text(data = df_sum_stat,
            x = 0.0,
            y = 0.93,
            size = 6,
            hjust = 0,
            mapping = aes(label = paste0("rmse: ", rmse))) +
  facet_wrap(~ model) +
  scale_x_continuous(breaks = c(0.0, 0.25, 0.50, 0.75, 1.00),
                     labels = c("0%", "25%", "50%", "75%", "100%")) +
  scale_y_continuous(breaks = c(0.0, 0.25, 0.50, 0.75, 1.00),
                     labels = c("0%", "25%", "50%", "75%", "100%")) +
  xlab("Model Prediction") +
  ylab("Participant Selection") +
  theme(plot.title = element_text(size = 20, hjust = 0.5),
        axis.title = element_text(size = 24),
        axis.text = element_text(size = 16),
        panel.spacing = unit(2, "lines"))
ggsave("figures/model_judgment.pdf",
       width = 10,
       height = 4)

# Correlation between the two models' predictions.
# Rows are paired by (trial, hole) rather than by position, so the result
# does not depend on the two tables happening to share the same row order.
with(
  inner_join(df_bandit_judge,
             df_fixed_judge,
             by = c("trial", "hole"),
             suffix = c("_sequential", "_uniform")),
  cor(prediction_sequential, prediction_uniform)
)
[1] 0.899042

EMD

# Per-trial earth mover's distance for each model, one table per model.
# `trial` is converted to a factor in all three tables; previously the bandit
# table left it as read, so rbind() of the three silently coerced the column
# to character.
df_emd_bandit <- read.csv("../python/model/model_performance/emd/bandit.csv") %>%
  select(trial, distance) %>%
  mutate(trial = factor(trial),
         model = "Sequential Sampler")

df_emd_fixed_sample <- read.csv("../python/model/model_performance/emd/fixed_sample.csv") %>%
  select(trial, distance) %>%
  mutate(trial = factor(trial),
         model = "Uniform Sampler")

df_emd_baseline <- read.csv("../python/model/model_performance/emd/visual_features.csv") %>%
  select(-X) %>%
  mutate(trial = factor(trial),
         model = "Visual Features")
# Trials to draw in the highlight colour (none at present).
to_highlight <- c()

# Fix the RNG so the horizontal jitter is reproducible.
set.seed(1)

# Stack the three EMD tables, map each model name to an x position (1, 2, 3),
# flag highlighted trials, and add a uniform jitter of +/- 0.15 to x.
df_to_show <- rbind(df_emd_bandit, df_emd_fixed_sample, df_emd_baseline) %>%
  mutate(
    model = as.numeric(match(
      model,
      c("Sequential Sampler", "Uniform Sampler", "Visual Features")
    )),
    highlight = trial %in% to_highlight,
    model_jitter = model + runif(n = n(), min = -0.15, max = 0.15)
  )

# Variant without the highlight colour mapping:
# ggplot(df_to_show, mapping = aes(x = model, y = distance)) +

# Jittered per-trial EMD for each model, faint lines linking the same trial
# across models, and a red bootstrapped mean / CI per model.
ggplot(df_to_show, aes(x = model, y = distance, color = highlight)) +
  geom_line(aes(x = model_jitter, group = trial), alpha = 0.05) +
  geom_point(aes(x = model_jitter), alpha = 0.5, shape = 16, size = 3) +
  stat_summary(fun.data = "mean_cl_boot", color = "red", size = 0.8) +
  scale_x_continuous(
    breaks = c(1, 2, 3),
    labels = c("Sequential Sampler", "Uniform Sampler", "Visual Features")
  ) +
  scale_color_manual(values = c("black", "magenta3")) +
  ylab("Earth Mover's Distance") +
  theme(
    legend.title = element_blank(),
    legend.position = "none",
    axis.title.y = element_text(size = 24),
    axis.title.x = element_blank(),
    axis.text = element_text(size = 16)
  )

ggsave("figures/emd_comparison.pdf", width = 8, height = 5)

# All three EMD tables stacked.
df_emd <- rbind(df_emd_bandit, df_emd_fixed_sample, df_emd_baseline)

# Bootstrapped mean and confidence limits of the distance per model,
# rounded to two digits.
df_emd %>%
  group_by(model) %>%
  do(data.frame(rbind(round(smean.cl.boot(.$distance), digits = 2))))
# A tibble: 3 × 4
# Groups:   model [3]
  model               Mean Lower Upper
  <chr>              <dbl> <dbl> <dbl>
1 Sequential Sampler  49.2  46.9  51.5
2 Uniform Sampler     65.0  62.6  67.6
3 Visual Features     71.5  69.0  73.9